import numpy as np
import gym
import os
from gym import spaces
from robosuite.controllers import load_controller_config
import imageio
import copy
import cv2
from Record.file_management import numpy_factored, display_frame
import robosuite.utils.macros as macros
from robosuite.wrappers import DataCollectionWrapper, VisualizationWrapper
# Import-time side effect: overrides robosuite's global SIMULATION_TIMESTEP macro
# for everything that reads it afterwards (value presumably in seconds — TODO confirm).
macros.SIMULATION_TIMESTEP = 0.02

import datetime
import h5py
# import init_path

import libero.libero.envs.bddl_utils as BDDLUtils
from libero.libero.envs.base_object import OBJECTS_DICT
from libero.libero.envs.regions import *
from libero.libero.envs import *
from Environment.environment import Environment
from Environment.Environments.Libero.libero_specs import *

# Control-mode identifiers, compared against the ``mode`` entry of a variant in LIBEROWorld.
# DEFAULT is not referenced directly in this file; any mode other than JOINT_MODE
# gets a 3-dim action that LIBEROWorld.set_action pads with four zeros.
DEFAULT = 0
# JOINT_MODE selects a 7-dim action that is passed to the simulator unchanged.
JOINT_MODE = 1

class LIBEROWorld(Environment):
    """Factored-state wrapper around a LIBERO (robosuite-based) task.

    The task is built from the BDDL file named by ``args.bddl_file`` and is
    driven through an ``OSC_POSE`` controller. Every state handed to callers
    is a dict with two entries: ``'raw_state'`` (the vertically flipped
    ``agentview_image`` observation, or ``None`` when the environment is not
    renderable) and ``'factored_state'`` (the simulator observation passed
    through ``numpy_factored``).
    """

    def __init__(self, args, variant="default", horizon=30, renderable=False, fixed_limits=False):
        """Build the simulator, the action/observation spaces and the first state.

        Args:
            args: namespace providing ``robots``, ``bddl_file`` and ``camera``.
            variant: key into the ``variants`` table (expected to come from the
                ``libero_specs`` star import); each entry unpacks to
                ``(control_freq, horizon, standard_reward, goal_reward, mode)``.
            horizon: episode horizon; a negative value selects the variant's
                own horizon. Only used here to derive ``timeout_penalty``.
            renderable: enables the on-screen and off-screen renderers and
                makes states carry an image in ``'raw_state'``.
            fixed_limits: stored on the instance; not otherwise used here.

        Raises:
            FileNotFoundError: if ``args.bddl_file`` does not exist.
        """
        super().__init__()
        self.fixed_limits = fixed_limits
        self.variant = variant
        control_freq, var_horizon, _standard_reward, goal_reward, mode = variants[variant]
        horizon = var_horizon if horizon < 0 else horizon
        self.mode = mode
        self.goal_reward = goal_reward
        self.controller = "OSC_POSE"

        # Controller and robot configuration forwarded to the task constructor.
        controller_config = load_controller_config(default_controller=self.controller)
        config = {
            "robots": args.robots,
            "controller_configs": controller_config,
        }

        print("args.robots:", args.robots)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not os.path.exists(args.bddl_file):
            raise FileNotFoundError("BDDL file not found: {}".format(args.bddl_file))
        problem_info = BDDLUtils.get_problem_info(args.bddl_file)
        problem_name = problem_info["problem_name"]
        language_instruction = problem_info["language_instruction"]

        print(language_instruction)
        # NOTE(review): ``use_camera_obs`` is False while ``_extract_frame``
        # reads "agentview_image" whenever ``renderable`` is True — confirm the
        # task still emits that observation in this configuration.
        self.env = TASK_MAPPING[problem_name](
            bddl_file_name=args.bddl_file,
            **config,
            has_renderer=renderable,
            has_offscreen_renderer=renderable,
            render_camera=args.camera,
            ignore_done=True,
            use_camera_obs=False,
            reward_shaping=True,
            control_freq=control_freq,
        )

        # environment properties
        self.name = "libero"  # required for an environment
        self.frameskip = control_freq
        self.timeout_penalty = -horizon

        # Only the leading ``limit`` entries of the simulator's action spec are
        # exposed to the agent; ``set_action`` pads the rest when needed.
        low, high = self.env.action_spec
        limit = 7 if self.mode == JOINT_MODE else 3
        self.action_shape = (limit,)
        self.action_space = spaces.Box(low=low[:limit], high=high[:limit])
        self.action = np.zeros(self.action_shape)

        self.renderable = renderable

        # factorized state properties
        # Snapshot the names as a list (the ordering of the factored objects)
        # rather than holding a reference to the simulator's own dict.
        self.object_names = list(self.env.objects_dict)
        self.all_names = ["Action", "Gripper", "Done", "Reward"] + self.object_names
        # name -> length of that object's state
        self.object_sizes = {"Action": limit, "Gripper": 1, "Done": 1, "Reward": 1}
        self.object_sizes.update((name, 1) for name in self.object_names)
        self.instance_length = len(self.all_names)

        # running values
        self.timer = 0

        # state components
        self.reward = 0
        self.done = False
        self.extracted_state = None

        # NOTE(review): the parsed problem is not used; the call is retained in
        # case parsing has side effects — confirm and remove if it is pure.
        BDDLUtils.robosuite_parse_problem(args.bddl_file)

        # position mask
        self.pos_size = 3

        # Populate ``full_state`` and ``frame`` with the initial observation.
        self.reset()
        self.frame = self.full_state['raw_state']  # the image generated by the environment
        self.reward_collect = 0
        self.observation_space = spaces.Box(low=-1, high=1, shape=[9])

    def _extract_frame(self, obs):
        """Return the vertically flipped agent-view image, or None if not renderable."""
        return obs["agentview_image"][::-1] if self.renderable else None

    def construct_full_state(self, factored_state, raw_state):
        """Store and return the full state built from an observation and a frame."""
        self.full_state = {'raw_state': raw_state, 'factored_state': numpy_factored(factored_state)}
        return self.full_state

    def set_action(self, action):
        """Convert an agent action into the action given to the simulator.

        In ``JOINT_MODE`` the action is passed through unchanged; otherwise four
        zeros are appended (presumably the rotation and gripper components of
        the OSC_POSE command — TODO confirm).
        """
        if self.mode == JOINT_MODE:
            return action
        return np.concatenate([action, [0, 0, 0, 0]])

    def check_contact(self, object1_name, object2_name):
        """Return the simulator's ``check_contact`` result for two named objects."""
        object_1 = self.env.get_object(object1_name)
        object_2 = self.env.get_object(object2_name)
        return self.env.check_contact(object_1, object_2)

    def check_contacts(self, object1_name):
        """Return the simulator's ``get_contacts`` result for a named object."""
        target = self.env.get_object(object1_name)
        return self.env.get_contacts(target)

    def step(self, action, render=False):
        """Advance the simulator by one action.

        ``render`` is accepted for interface compatibility and is ignored;
        whether a frame is produced depends only on ``self.renderable``.

        The episode ends when the simulator reports done (flagged as a
        time-limit truncation) or when the reward equals ``goal_reward``
        (flagged as a genuine termination). On episode end the environment is
        reset automatically, but the returned state is the one observed before
        that reset.

        Returns:
            Tuple ``(obs, reward, done, info)`` where ``obs`` is a deep copy of
            the full state dict.
        """
        # step internal robosuite environment
        self.action = action
        use_act = self.set_action(action)
        next_obs, self.reward, self.done, info = self.env.step(use_act)
        self.reward_collect += self.reward

        # Reaching the goal takes precedence over a time-limit truncation.
        reached_goal = bool(self.reward == self.goal_reward)  # don't wait at the goal, just terminate
        info["TimeLimit.truncated"] = bool(self.done) and not reached_goal
        if self.done or reached_goal:
            # Report the episode return once, even if both conditions hold.
            print("end of episode:", self.reward_collect)
            self.done = True

        self.construct_full_state(next_obs, self._extract_frame(next_obs))
        obs = self.get_state()
        self.frame = self.full_state['raw_state']

        # step timers
        # NOTE(review): ``itr`` is not set in this class — presumably
        # initialised by ``Environment.__init__``; confirm.
        self.itr += 1
        self.timer += 1

        if self.done:
            self.reset()  # also zeroes ``timer`` and ``reward_collect``

        return obs, self.reward, self.done, info  # obs is full_state

    def get_full_trace(self, factored_state, action, target_name):
        """Return an all-ones vector with one entry per name in ``all_names``."""
        return np.ones(len(self.all_names))

    def reset(self):
        """Reset the simulator and per-episode counters; return a copy of the new state."""
        obs = self.env.reset()
        self.reward_collect = 0
        self.timer = 0
        self.frame = self._extract_frame(obs)
        self.full_state = self.construct_full_state(obs, self.frame)
        return self.get_state()

    def render(self):
        """Return the most recent frame (None when not renderable)."""
        return self.frame

    def get_state(self, render=False):
        """Return a deep copy of the current full state; ``render`` is ignored."""
        return copy.deepcopy(self.full_state)

    def toString(self, extracted_state):
        """Format ``extracted_state`` as a tab-separated ``name:values`` line.

        "Reward" and "Done" are written as the integer value of their first
        element; every other entry is written as its space-separated values.
        """
        # NOTE(review): ``objects`` and ``itr`` are not set in this class —
        # presumably provided by the ``Environment`` base class; confirm.
        parts = ["ITR:" + str(self.itr) + "\t"]
        for obj in self.objects:
            if obj in ("Reward", "Done"):
                value = str(int(extracted_state[obj][0]))
            else:
                # TODO: attributes are limited to single floats
                value = " ".join(map(str, extracted_state[obj]))
            parts.append(obj + ":" + value + "\t")
        return "".join(parts)